//===- epilogue_vectorized_reduce_scatter.hpp --------------------- C++ ---===//
//
// Copyright 2025 ByteDance Ltd. and/or its affiliates. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
// Some code from cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp
// in NVIDIA cutlass project
// Original license as follows
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Vectorized epilogue that applies an elementwise operation and scatters the output
         tile across per-rank destination buffers.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "flux/cuda/cuda_common.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies an element wise operation to all elements within the accumulator
/// fragment and writes the result to one of several destination buffers.
///
/// The destination buffer is chosen per output vector: the M extent is split into
/// `world_size` shards of `M / world_size` rows, and a vector whose row falls in
/// shard `r` is written through `Params::scatter_ptr_D[r]`, shifted by
/// `(M / world_size) * N * (rank - r)` elements.
///
/// Data path per SMEM-sized sub-tile: accumulator registers -> SMEM (CopyAtomR2S)
/// -> registers (TiledCopyS2R) -> elementwise op -> GMEM (CopyAtomR2G).
///
/// Ways to generalize this:
/// - CTA tile shape
/// - vectorization requirements (GMEM)
/// - vectoriz(able) transform()
///
template <
    class StrideC_,
    class StrideD_,
    class ThreadEpilogueOp_,
    class SmemLayout_,
    class CopyAtomR2S_,
    class TiledCopyS2R_,
    class CopyAtomR2G_>
class EpilogueReduceScatterVectorized {
 public:
  //
  // Type Aliases
  //
  // derived types of output thread level operator
  using ThreadEpilogueOp = ThreadEpilogueOp_;
  using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
  using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
  using ElementScalar = ElementCompute;
  using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
  using ElementC = typename ThreadEpilogueOp::ElementC;
  using StrideC = StrideC_;
  using ElementD = typename ThreadEpilogueOp::ElementD;
  using StrideD = StrideD_;

  using SmemLayout = SmemLayout_;
  using CopyAtomR2S = CopyAtomR2S_;
  using TiledCopyS2R = TiledCopyS2R_;
  using CopyAtomR2G = CopyAtomR2G_;

  // Number of output elements per vectorized access, and an unsigned integer
  // type with exactly that many bits.
  static const int kOutputAlignment = ThreadEpilogueOp::kCount;
  using AlignmentType =
      typename cute::uint_bit<sizeof_bits<ElementOutput>::value * kOutputAlignment>::type;

  static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]");
  static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]");

  // Shared memory staging buffer holding one SmemLayout-shaped tile of accumulators.
  struct SharedStorage {
    cute::array_aligned<ElementAccumulator, cute::cosize_v<SmemLayout>> smem_epilogue;
  };

  // Host side epilogue arguments
  struct Arguments {
    // Index of the calling rank; used to compute the write offset into each
    // destination buffer. NOTE: no default initializer.
    int32_t rank;
    // Number of ranks; M is divided by this to obtain the per-rank shard height.
    // NOTE: no default initializer.
    int32_t world_size;
    // Parameters forwarded to the thread-level elementwise op.
    typename ThreadEpilogueOp::Params thread{};
    // Source tensor C and its (M, N, L) stride.
    ElementC const *ptr_C = nullptr;
    StrideC dC{};
    // Array of destination pointers, indexed by destination rank.
    ElementD **scatter_ptr_D = nullptr;
    // (M, N, L) stride applied to every destination tensor.
    StrideD dD{};
    // Not referenced by this class.
    int **barrier_ptr_aux = nullptr;
  };

  // Device side epilogue params
  using Params = Arguments;

  //
  // Methods
  //

  /// Converts host arguments to device params; here they are the same type, so
  /// `args` is returned unchanged and the problem shape / workspace are ignored.
  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      [[maybe_unused]] ProblemShape const &_,
      Arguments const &args,
      [[maybe_unused]] void *workspace) {
    return args;
  }

  /// Always reports the problem as implementable; no argument validation is done.
  template <class ProblemShape>
  CUTLASS_HOST_DEVICE static bool
  can_implement(
      [[maybe_unused]] ProblemShape const &problem_shape, [[maybe_unused]] Arguments const &args) {
    return true;
  }

  /// Stores a copy of the params and constructs the thread-level op from `params_.thread`.
  CUTLASS_HOST_DEVICE
  EpilogueReduceScatterVectorized(Params const &params_)
      : params(params_), epilogue_op(params_.thread) {}

  /// Forwards to the thread-level op: true when the source tensor C must be read.
  CUTLASS_DEVICE
  bool
  is_source_needed() {
    return epilogue_op.is_source_needed();
  }

  /// Runs the epilogue for one CTA tile.
  ///
  /// \param problem_shape_mnkl  rank-4 problem shape; modes 0, 1 and 3 are read as M, N, L
  /// \param blk_shape_MNK       static rank-3 CTA tile shape
  /// \param blk_coord_mnkl      rank-4 coordinate of this CTA's tile
  /// \param accumulators        accumulator fragment, (MMA,MMA_M,MMA_N)
  /// \param tiled_mma           tiled MMA used to derive the register->SMEM copy
  /// \param residue_mnk         modes 0 and 1 bound the in-tile (m, n) coordinates that are written
  /// \param thread_idx          thread index used to slice the tiled copies
  /// \param smem_buf            shared memory buffer, reinterpreted as SharedStorage
  template <
      class ProblemShapeMNKL,
      class BlockShapeMNK,
      class BlockCoordMNKL,
      class FrgEngine,
      class FrgLayout,
      class TiledMma,
      class ResidueMNK>
  CUTLASS_DEVICE void
  operator()(
      ProblemShapeMNKL problem_shape_mnkl,
      BlockShapeMNK blk_shape_MNK,
      BlockCoordMNKL blk_coord_mnkl,
      cute::Tensor<FrgEngine, FrgLayout> const &accumulators,  // (MMA,MMA_M,MMA_N)
      TiledMma tiled_mma,
      ResidueMNK residue_mnk,
      int thread_idx,
      char *smem_buf) {
    using namespace cute;
    using X = Underscore;

    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
    static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
    static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4");

    // synchronizing function for smem reads/writes
#if CUDA_BARRIER_ENABLED
    auto synchronize = []() {
      cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, 0);
    };
#else
    auto synchronize = []() { __syncthreads(); };
#endif

    // Separate out problem shape for convenience
    auto M = get<0>(problem_shape_mnkl);
    auto N = get<1>(problem_shape_mnkl);
    auto L = get<3>(problem_shape_mnkl);

    // Represent the full source tensor C
    Tensor mC_mnl = make_tensor(
        make_gmem_ptr(params.ptr_C),
        make_shape(M, N, L),
        params.dC);  //             (m,n,l)
    Tensor gC_mnl = local_tile(
        mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N,m,n,l)

    // Slice to get the tile this CTA is responsible for
    auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
    Tensor gC = gC_mnl(_, _, m_coord, n_coord, l_coord);  // (BLK_M,BLK_N)

    // Construct a tensor in SMEM that we can partition for rearranging data
    SharedStorage &storage = *reinterpret_cast<SharedStorage *>(smem_buf);
    Tensor sC =
        make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{});  // (SMEM_M,SMEM_N)

    // Partition sC to match the accumulator partitioning
    auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma);
    auto tC = tiled_r2s.get_thread_slice(thread_idx);
    Tensor tCaC = tC.retile_S(accumulators);  // ((Atom,AtomNum), MMA_M, MMA_N)
    Tensor tCsC = tC.partition_D(sC);         // ((Atom,AtomNum),PIPE_M,PIPE_N)

    // Tile gD and gC by the shape of SmemLayout first
    auto tile = make_shape(size<0>(sC), size<1>(sC));
    Tensor gCt = flat_divide(gC, tile);  // (SMEM_M,SMEM_N,TILE_M,TILE_N)

    // Partition sC, gC, and gD for the output
    auto tiled_s2r = TiledCopyS2R{};
    auto tD = tiled_s2r.get_thread_slice(thread_idx);
    Tensor tDsC = tD.partition_S(sC);   //               ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tDgC = tD.partition_D(gCt);  // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N)

    // Allocate intermediate registers on the dst tensors
    Tensor tDrC = make_tensor<ElementAccumulator>(
        take<0, 3>(shape(tDgC)));                           // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tDrD = make_tensor<ElementOutput>(shape(tDrC));  // ((Atom,AtomNum),ATOM_M,ATOM_N)

    // Repeat the D-partitioning for coordinates and predication
    Tensor cD = make_identity_tensor(
        make_shape(size<0>(gC), size<1>(gC)));  // (BLK_M,BLK_N) -> (blk_m,blk_n)
    Tensor cDt = flat_divide(cD, tile);         //                (SMEM_M,SMEM_N,TILE_M,TILE_N)
    Tensor tDcD = tD.partition_D(cDt);          // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N)

    CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0);  // TILE_M divides MMA_M
    CUTE_STATIC_ASSERT(size<2>(tCaC) % size<4>(tDgC) == 0);  // TILE_N divides MMA_N
    // bytedance::flux::StaticPrintInt<typename TiledCopyS2R::TiledNumThr{}>{};
    // bytedance::flux::StaticPrintInt<size<0>(typename TiledMma::AtomLayoutC_TV{})>{};
    CUTE_STATIC_ASSERT(
        typename TiledCopyS2R::TiledNumThr{} == size<0>(typename TiledMma::AtomLayoutC_TV{}));

    // Rows per rank. Integer division: any remainder of M / world_size is dropped.
    int tp_shard_M = M / this->params.world_size;

    // mapping element to dst rank: every stride is 0 except stride 1 on the
    // world_size sub-mode of M, so the layout value depends only on which
    // tp_shard_M-row shard the M coordinate falls in.
    auto scatter_layout = make_layout(
        make_shape(make_shape(tp_shard_M, this->params.world_size), N, L),
        make_stride(make_stride(Int<0>{}, Int<1>{}), Int<0>{}, Int<0>{}));
    // follow the same tiling procedure
    auto mScatter_mnl = make_counting_tensor(scatter_layout);
    auto gScatter_mnl =
        local_tile(mScatter_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{});
    auto gScatter = gScatter_mnl(_, _, m_coord, n_coord, l_coord);

    // Builds this thread's partition of the CTA's D tile inside the buffer of `dst_rank`.
    auto get_tDgD = [&](int dst_rank) {
      // Re-unpack the block coordinate locally: structured bindings cannot be
      // captured by a lambda before C++20.
      auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;
      // add the offset so that each rank write to different range of data.
      // The product is computed in ptrdiff_t: evaluated in int it can exceed
      // 32 bits for large M * N before being widened.
      // NOTE(review): the offset uses N as the row pitch regardless of params.dD
      // -- confirm D is always row-major with leading dimension N.
      ptrdiff_t offset = static_cast<ptrdiff_t>(tp_shard_M) * static_cast<ptrdiff_t>(N) *
                         static_cast<ptrdiff_t>(this->params.rank - dst_rank);
      Tensor mD_mnl = make_tensor(
          make_gmem_ptr(params.scatter_ptr_D[dst_rank]) + offset, make_shape(M, N, L), params.dD);

      Tensor gD_mnl = local_tile(
          mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{});  // (BLK_M,BLK_N,m,n,l)
      Tensor gD = gD_mnl(_, _, m_coord, n_coord, l_coord);
      Tensor gDt = flat_divide(gD, tile);  // (SMEM_M,SMEM_N,TILE_M,TILE_N)
      Tensor tDgD = tD.partition_D(gDt);   // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N)

      return tDgD;
    };
    // NOTE: the disabled debug dump below is stale: gD, gDt and tDgD are only
    // declared inside get_tDgD, so it will not compile if enabled as-is.
#if 0
    if (thread_idx == 0 && m_coord == 0 && n_coord == 0) {
      print("aC   : "); print(accumulators.layout()); print("\n");
      print("gC   : "); print(gC.layout()); print("\n");
      print("gD   : "); print(gD.layout()); print("\n");
      print("sC   : "); print(sC.layout()); print("\n");
      print("\n");
      print("tCsC : "); print(tCsC.layout()); print("\n");
      print("tCaC : "); print(tCaC.layout()); print("\n");
      print("\n");
      print("gDt  : "); print(gDt.layout()); print("\n");
      print("tDsC : "); print(tDsC.layout()); print("\n");
      print("tDrC : "); print(tDrC.layout()); print("\n");
      print("\n");
      print("tDrD : "); print(tDrD.layout()); print("\n");
      print("tDgC : "); print(tDgC.layout()); print("\n");
      print("tDgD : "); print(tDgD.layout()); print("\n");
      print("\n");
    }
#endif

    // For each tiling needed for SmemLayout to cover shape(gD)
    CUTLASS_PRAGMA_UNROLL
    for (int step_m = 0; step_m < size<2>(cDt); ++step_m) {
      CUTLASS_PRAGMA_UNROLL
      for (int step_n = 0; step_n < size<3>(cDt); ++step_n) {
        // Step 1. Copy to SMEM
        CUTLASS_PRAGMA_UNROLL
        for (int pipe_m = 0; pipe_m < size<1>(tCsC); ++pipe_m) {
          CUTLASS_PRAGMA_UNROLL
          for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) {
            int mma_m = step_m * size<1>(tCsC) + pipe_m;
            int mma_n = step_n * size<2>(tCsC) + pipe_n;

            copy(tiled_r2s, tCaC(_, mma_m, mma_n), tCsC(_, pipe_m, pipe_n));
          }
        }

        // Step 2. Wait for SMEM writes to complete
        synchronize();

        // Step 3. Copy from SMEM into a fragment
        copy(tiled_s2r, tDsC, tDrC);

        // Step 4. Wait for SMEM reads to complete
        synchronize();

        Tensor tDcDmn = tDcD(_, _, _, step_m, step_n);

        if (epilogue_op.is_source_needed()) {
          // source is needed
          Tensor tDgCmn = tDgC(_, _, _, step_m, step_n);
          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < size<1>(tDcDmn); ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < size<2>(tDcDmn); ++n) {
              // Predication (tested on the first element of each vector only)
              if (get<0>(tDcDmn(0, m, n)) < get<0>(residue_mnk) &&
                  get<1>(tDcDmn(0, m, n)) < get<1>(residue_mnk)) {
                // Step 5. Elementwise operation with conversion
                CUTLASS_PRAGMA_UNROLL
                for (int i = 0; i < size<0>(tDrC); ++i) {
                  tDrD(i, m, n) = epilogue_op(tDrC(i, m, n), tDgCmn(i, m, n));
                }
                // Destination rank is looked up from the first element's (m, n)
                // coordinate and applied to the whole vector.
                int dst_rank = gScatter(get<0>(tDcDmn(0, m, n)), get<1>(tDcDmn(0, m, n)));
                auto tDgD = get_tDgD(dst_rank);
                Tensor tDgDmn = tDgD(_, _, _, step_m, step_n);

                // Step 6. Copy to GMEM
                copy(CopyAtomR2G{}, tDrD(_, m, n), tDgDmn(_, m, n));
              }
            }
          }
        } else {
          // source is not needed, avoid load and lift compute

          // Step 5. Elementwise operation with conversion
          CUTLASS_PRAGMA_UNROLL
          for (int i = 0; i < size(tDrC); ++i) {
            tDrD(i) = epilogue_op(tDrC(i));
          }

          CUTLASS_PRAGMA_UNROLL
          for (int m = 0; m < size<1>(tDcDmn); ++m) {
            CUTLASS_PRAGMA_UNROLL
            for (int n = 0; n < size<2>(tDcDmn); ++n) {
              // Predication (tested on the first element of each vector only)
              if (get<0>(tDcDmn(0, m, n)) < get<0>(residue_mnk) &&
                  get<1>(tDcDmn(0, m, n)) < get<1>(residue_mnk)) {
                // Destination rank is looked up from the first element's (m, n)
                // coordinate and applied to the whole vector.
                int dst_rank = gScatter(get<0>(tDcDmn(0, m, n)), get<1>(tDcDmn(0, m, n)));
                auto tDgD = get_tDgD(dst_rank);
                Tensor tDgDmn = tDgD(_, _, _, step_m, step_n);
                // Step 6. Copy to GMEM
                copy(CopyAtomR2G{}, tDrD(_, m, n), tDgDmn(_, m, n));
              }
            }
          }
        }
      }
    }
  }

 private:
  Params params;
  ThreadEpilogueOp epilogue_op;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace collective
}  // namespace epilogue
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
